[triton-raise-block-pointer]: Introduce env. variable to ignore masked load/stores #3416

Merged
merged 8 commits into from
Feb 13, 2025

etiotto
Contributor

@etiotto etiotto commented Feb 12, 2025

When tt.load and tt.store operations have a mask, the compiler cannot safely "raise" them to use block pointers (block pointer loads/stores are unmasked).

This PR introduces a sub-option (ignore-masks) for the environment variable TRITON_INTEL_RAISE_BLOCK_POINTER. The sub-option allows the compiler to rewrite masked loads/stores into unmasked ones before attempting conversion to block pointer loads/stores.
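What ignoring a mask gives up can be illustrated with a toy model in plain Python. This is not Triton code: `masked_load` and `unmasked_load` are hypothetical helpers that only stand in for a load with and without a mask. Under that assumption, dropping the mask preserves the result only when every mask element is true.

```python
def masked_load(memory, offsets, mask, other=0):
    """Toy stand-in for a masked load: lanes with a false mask yield `other`."""
    return [memory[o] if m else other for o, m in zip(offsets, mask)]


def unmasked_load(memory, offsets):
    """Toy stand-in for an unmasked load: every lane reads memory."""
    return [memory[o] for o in offsets]


memory = [10, 20, 30, 40, 50, 60, 70, 80]
offsets = [4, 5, 6, 7]

# Mask that is true for every lane: dropping it changes nothing.
all_true = [o < len(memory) for o in offsets]
same = masked_load(memory, offsets, all_true) == unmasked_load(memory, offsets)

# Mask that filters a lane for data-flow reasons: dropping it changes the result.
partial = [True, False, True, True]
differs = masked_load(memory, offsets, partial) != unmasked_load(memory, offsets)
```

Here `same` is true and `differs` is true, which is why the sub-option amounts to a user assertion rather than a safe default.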

@etiotto
Contributor Author

etiotto commented Feb 12, 2025

Note: this PR allows users to assert that the masks can be dropped (i.e. will always evaluate to true). This option is a stop-gap. I plan to work on mask analysis next.

@mfrancepillois
Contributor

A next step in mask management could be to check whether masks only guard against out-of-bounds accesses (as opposed to controlling data flow), and to set the ‘ignore-masks’ flag automatically if that is the case. I think boundary-control masks should be simple enough for the pass to identify?
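A minimal sketch of that kind of check, in plain Python and under loud assumptions: the tuple encoding of mask expressions below is hypothetical (it is not Triton IR), and `is_boundary_mask` is an illustrative name, not something in the pass. It accepts only "index < bound" comparisons combined with "and", and treats everything else as data flow.

```python
def is_boundary_mask(expr):
    """Return True if `expr` only compares index ranges against bounds.

    Hypothetical expression encoding (not Triton IR):
      ("lt", ("index", name), ("bound", name))  -> offs < N
      ("and", lhs, rhs)                         -> lhs & rhs
    Anything else (e.g. a comparison on loaded data) is treated as data flow.
    """
    kind = expr[0]
    if kind == "and":
        return is_boundary_mask(expr[1]) and is_boundary_mask(expr[2])
    if kind == "lt":
        lhs, rhs = expr[1], expr[2]
        return lhs[0] == "index" and rhs[0] == "bound"
    return False


# (offs_m < M) & (offs_k < K): pure boundary control.
bounds_only = ("and",
               ("lt", ("index", "offs_m"), ("bound", "M")),
               ("lt", ("index", "offs_k"), ("bound", "K")))

# (offs_m < M) & (x > 0): the second term depends on data.
with_data = ("and",
             ("lt", ("index", "offs_m"), ("bound", "M")),
             ("gt", ("data", "x"), ("const", 0)))
```

With this encoding, `is_boundary_mask(bounds_only)` is true and `is_boundary_mask(with_data)` is false. Recognizing a boundary mask is not by itself proof that it always evaluates to true; that still depends on the tensor shape and block size.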

@alexbaden
Contributor

I agree that there are broadly two cases of masks to distinguish between, but I don't think we should make that explicit in the code. Do the masks prevent the usage of block pointers or just block loads? Hypothetically, if we can use block ptrs with masked loads and fall back to the gather/scalar load, then we can develop heuristics to determine whether or not a 2D block load w/out masks is safe as an optimization step in the load lowering.

@etiotto
Contributor Author

etiotto commented Feb 13, 2025

> I agree that there are broadly two cases of masks to distinguish between, but I don't think we should make that explicit in the code. Do the masks prevent the usage of block pointers or just block loads? Hypothetically, if we can use block ptrs with masked loads and fall back to the gather/scalar load, then we can develop heuristics to determine whether or not a 2D block load w/out masks is safe as an optimization step in the load lowering.

The tt.load/tt.store operations do not accept a mask if the ptr operand is a block pointer. I modified tutorial 10 as:

    a = tl.load(a_block_ptr, mask=True)

And we then get a compilation error:

ValueError: `mask` and `other` arguments cannot be specified for loading block pointers

We may be able to "bypass" that error if we change the tt.load after the diagnostic fires (that is, the user cannot legally write that code, but the compiler could generate it).

The first goal IMO is to determine whether the mask always evaluates to true (or false) for each loop iteration. In that case we can simply remove the mask (if always true) or propagate 'other' (if always false).
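A brute-force sketch of that classification in plain Python, for a boundary mask of the form `offs < n`. The assumptions are mine: iteration `i` covers offsets `[i * block, (i + 1) * block)`, and `classify_mask` / `rewrite_load` are hypothetical names. A real analysis would presumably reason symbolically rather than enumerate offsets.

```python
def classify_mask(n, block, num_iters):
    """Classify the boundary mask `offs < n` over a loop.

    Assumes iteration i covers offsets [i * block, (i + 1) * block).
    Returns "always_true", "always_false", or "mixed".
    """
    saw_true = saw_false = False
    for i in range(num_iters):
        for off in range(i * block, (i + 1) * block):
            if off < n:
                saw_true = True
            else:
                saw_false = True
    if saw_true and not saw_false:
        return "always_true"
    if saw_false and not saw_true:
        return "always_false"
    return "mixed"  # also the conservative answer when nothing was observed


def rewrite_load(kind):
    """Pick a rewrite for a given classification."""
    if kind == "always_true":
        return "drop mask"
    if kind == "always_false":
        return "replace load with `other`"
    return "keep masked load"


aligned = classify_mask(n=128, block=32, num_iters=4)    # 128 is a multiple of 32
unaligned = classify_mask(n=100, block=32, num_iters=4)  # last block is partial
```

In this model `aligned` is "always_true" (the mask can be dropped) and `unaligned` is "mixed" (the masked load has to stay).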

@etiotto etiotto marked this pull request as ready for review February 13, 2025 14:02
@etiotto etiotto merged commit e9adc3c into main Feb 13, 2025
5 checks passed
@etiotto etiotto deleted the etiotto.raise_block_ptr.15 branch February 13, 2025 14:54
Comment on lines +581 to +584
    .Case<tt::LoadOp>(
        [this](auto loadOp) { return IgnoreMasks || !loadOp.getMask(); })
    .Case<tt::StoreOp>([this](auto storeOp) {
      return IgnoreMasks || !storeOp.getMask();
Contributor

These two cases can likely be combined.

@alexbaden
Contributor

> The tt.load/tt.store operations do not accept a mask if the ptr operand is a block pointer.

This makes sense because masked loads are not a block ptr option in the language - likely b/c computing the mask would be inefficient. If our goal is to use the block ptr machinery (namely 2d block load/store/prefetch) then do we need to internally lower to block ptr? If so we probably should respect the existing language conventions so we don’t diverge from upstream or have nasty surprises later. If we need to support the mask then maybe we need a different lowering strategy.

4 participants